import argparse
import logging
import os
import sys
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

from datetime import datetime
import pytz

from lib.networks import EMCADNet
from trainer import trainer_synapse

import gc

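# Example invocation (the script filename below is an assumption; point the paths at the
# locally prepared .npz training slices and .h5 evaluation volumes):
#     python train_synapse.py --root_path /path/to/train_npz_new \
#         --volume_path /path/to/test_vol_h5_new --batch_size 6 --max_epochs 300
# Decoder variants are toggled with flags such as --no_dw_parallel, --concatenation and
# --no_pretrain, all parsed below.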
if __name__ == '__main__':
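    # Release leftover Python objects and cached CUDA memory before building the model.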
    gc.collect()
    torch.cuda.empty_cache()

    print('Start training...')

    parser = argparse.ArgumentParser()

    parser.add_argument('--root_path', type=str,default='/path/to/train_npz_new', help='root dir for data') # TODO: replace with actual path
    parser.add_argument('--volume_path', type=str,default='/path/to/test_vol_h5_new', help='root dir for validation volume data') # TODO: replace with actual path
    parser.add_argument('--dataset', type=str,default='Synapse', help='experiment_name')
    parser.add_argument('--list_dir', type=str,default='./lists/lists_Synapse', help='list dir')
    parser.add_argument('--num_classes', type=int,default=9, help='number of output channels (classes) of the network')

    # network related parameters
    parser.add_argument('--encoder', type=str,
                        default='pvt_v2_b2', help='Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')
    parser.add_argument('--expansion_factor', type=int,
                        default=2, help='expansion factor in MSCB block')
    parser.add_argument('--kernel_sizes', type=int, nargs='+',
                        default=[1, 3, 5], help='multi-scale kernel sizes in MSDC block')
    parser.add_argument('--lgag_ks', type=int,
                        default=3, help='Kernel size in LGAG')
    parser.add_argument('--activation_mscb', type=str,
                        default='relu6', help='activation used in MSCB: relu6 or relu')
    parser.add_argument('--no_dw_parallel', action='store_true', 
                        default=False, help='use this flag to disable depth-wise parallel convolutions')
    parser.add_argument('--concatenation', action='store_true', 
                        default=False, help='use this flag to concatenate feature maps in MSDC block')
    parser.add_argument('--no_pretrain', action='store_true',
                        default=False, help='use this flag to turn off loading pretrained encoder weights')
    parser.add_argument('--supervision', type=str,
                        default='mutation', help='loss supervision: mutation, deep_supervision or last_layer')

    parser.add_argument('--max_iterations', type=int,
                        default=50000, help='maximum iteration number to train')
    parser.add_argument('--max_epochs', type=int,
                        default=300, help='maximum epoch number to train')
    parser.add_argument('--batch_size', type=int,
                        default=6, help='batch_size per gpu')
    parser.add_argument('--base_lr', type=float,  default=0.0001,
                        help='segmentation network learning rate')
    parser.add_argument('--img_size', type=int,
                        default=224, help='input patch size of network input')
    parser.add_argument('--n_gpu', type=int, default=1, help='total number of GPUs')
    parser.add_argument('--deterministic', type=int,  default=1,
                        help='whether to use deterministic training')
    parser.add_argument('--seed', type=int,
                        default=2222, help='random seed')
    # --dual is on by default; --no-dual switches to single supervision (argparse.BooleanOptionalAction, Python 3.9+)
    parser.add_argument('--dual', action=argparse.BooleanOptionalAction, default=True,
                        help='dual supervision (default) or single supervision via --no-dual')

    args = parser.parse_args()

    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True
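    # Note: the cudnn flags above do not guarantee bit-wise reproducibility on their own;
    # torch.use_deterministic_algorithms(True) could additionally be enabled for stricter
    # determinism, at the cost of errors for ops without deterministic implementations.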

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)  # seed all visible GPUs, not only the current device

    dataset_name = args.dataset
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'volume_path': args.volume_path,
            'list_dir': args.list_dir,
            'num_classes': args.num_classes,
            'z_spacing': 1,
        },
    }
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.root_path = dataset_config[dataset_name]['root_path']
    args.volume_path = dataset_config[dataset_name]['volume_path']
    args.z_spacing = dataset_config[dataset_name]['z_spacing']
    args.list_dir = dataset_config[dataset_name]['list_dir']
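    # Further datasets could be supported by adding an entry here plus a matching trainer in
    # the dispatch dict at the bottom; a hypothetical example (all values are placeholders):
    #     'ACDC': {
    #         'root_path': '/path/to/ACDC/train_npz',
    #         'volume_path': '/path/to/ACDC/test_vol_h5',
    #         'list_dir': './lists/lists_ACDC',
    #         'num_classes': 4,
    #         'z_spacing': 1,
    #     },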

    if args.concatenation:
        aggregation = 'concat'
    else: 
        aggregation = 'add'

    if args.no_dw_parallel:
        dw_mode = 'series'
    else: 
        dw_mode = 'parallel'
        
    run = 1
    args.exp = (f'{args.encoder}_EMCAD_kernel_sizes_{args.kernel_sizes}_dw_{dw_mode}_{aggregation}'
                f'_lgag_ks_{args.lgag_ks}_ef{args.expansion_factor}_act_mscb_{args.activation_mscb}'
                f'_loss_{args.supervision}_output_final_layer_Run{run}_{dataset_name}{args.img_size}')
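    # With the defaults above this resolves to something like:
    #     pvt_v2_b2_EMCAD_kernel_sizes_[1, 3, 5]_dw_parallel_add_lgag_ks_3_ef2_act_mscb_relu6_loss_mutation_output_final_layer_Run1_Synapse224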
    if args.dual:
        args.exp = 'dual_' + args.exp
        model = EMCADNet(num_classes=args.num_classes, kernel_sizes=args.kernel_sizes,
                         expansion_factor=args.expansion_factor, dw_parallel=not args.no_dw_parallel,
                         add=not args.concatenation, lgag_ks=args.lgag_ks,
                         activation=args.activation_mscb, encoder=args.encoder,
                         pretrain=not args.no_pretrain, dual=True).cuda()
    else:
        model = EMCADNet(num_classes=args.num_classes, kernel_sizes=args.kernel_sizes,
                         expansion_factor=args.expansion_factor, dw_parallel=not args.no_dw_parallel,
                         add=not args.concatenation, lgag_ks=args.lgag_ks,
                         activation=args.activation_mscb, encoder=args.encoder,
                         pretrain=not args.no_pretrain).cuda()
        
    # Use path-safe separators in the timestamp so it can be embedded in the snapshot directory name.
    current_time = datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d_%H")
    print("The current time is", current_time)
    args.exp = 'run_' + current_time + '_' + args.exp


    snapshot_path = "./model_pth/Synapse/{}".format(args.exp)
    snapshot_path = snapshot_path + '_pretrain' if not args.no_pretrain else snapshot_path
    snapshot_path = snapshot_path+'_'+str(args.max_iterations)[0:2]+'k' if args.max_iterations != 50000 else snapshot_path
    snapshot_path = snapshot_path + '_epo' +str(args.max_epochs) if args.max_epochs != 300 else snapshot_path
    snapshot_path = snapshot_path+'_bs'+str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.0001 else snapshot_path
    snapshot_path = snapshot_path + '_'+str(args.img_size)
    snapshot_path = snapshot_path + '_s'+str(args.seed) if args.seed!=1234 else snapshot_path
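    # With the defaults this yields something along the lines of:
    #     ./model_pth/Synapse/run_<timestamp>_dual_pvt_v2_b2_EMCAD_..._pretrain_bs6_224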

    print('snapshot_path: ', snapshot_path)
    os.makedirs(snapshot_path, exist_ok=True)

    print('Model successfully created.')

    trainer = {'Synapse': trainer_synapse,}
    trainer[dataset_name](args, model, snapshot_path)
